#!/usr/bin/env python
# -*- coding:utf-8 -*-

"""
爬虫基类
"""

import sys
from datetime import timedelta
from pprint import pprint

import requests
from tornado import httpclient, gen, ioloop, queues

from request_util import get_random_ua
from xawesome_codechecker import timeit
# Python 2 hack: ``sys.setdefaultencoding`` is deleted by site.py at startup,
# so reload ``sys`` to restore it and switch the implicit str<->unicode
# conversion codec from ASCII to UTF-8 (module has Chinese docstrings).
reload(sys)
sys.setdefaultencoding("utf-8")


class BaseCrawler(object):
    """Synchronous crawler base class.

    Wraps a ``requests.Session`` (with a randomized User-Agent) and adds
    simple retry-on-exception behaviour to the GET/POST helpers.

    Keyword args:
        logger: optional logger object; its ``info``/``exception`` methods
            are used for output. Falls back to ``pprint`` when absent.
        Any other keyword arguments become instance attributes.
    """

    def __init__(self, **kwargs):
        # pop() replaces the original get()+del pair: one lookup, same effect.
        logger = kwargs.pop('logger', None)
        if logger:
            self._log = logger.info
            self._exception = logger.exception
        else:
            # No logger supplied: dump messages/exceptions with pprint.
            self._exception = self._log = pprint
        # Remaining kwargs become instance attributes.
        self.__dict__ = dict(self.__dict__, **kwargs)
        self._request = requests.Session()
        self._request.headers['User-Agent'] = get_random_ua()

    def get_raw(self, url, timeout=10, times=3):
        """GET ``url``; retry up to ``times`` attempts on any exception.

        :param url: target URL
        :param timeout: per-attempt timeout in seconds
        :param times: remaining attempts; ``None`` is returned when exhausted
        :return: ``requests.Response`` or None
        """
        # ``<= 0`` (was ``== 0``): a negative ``times`` must not recurse
        # forever.
        if times <= 0:
            return None
        try:
            return self._request.get(url, timeout=timeout)
        except Exception as e:  # ``as`` form works on Py2.6+ and Py3 (was ``except Exception, e``)
            self._exception(e)
            return self.get_raw(url, timeout=timeout, times=times - 1)

    def get(self, url, timeout=10, times=3):
        """GET ``url`` and return the response body, or None on failure.

        NOTE(review): relies on ``Response`` truthiness -- a 4xx/5xx response
        is falsy, so error responses also yield None here (kept as-is).
        """
        raw = self.get_raw(url, timeout=timeout, times=times)
        if raw:
            return raw.content
        return None

    def post_raw(self, url, data, headers=None, timeout=10, times=3):
        """POST ``data`` to ``url``; retry up to ``times`` attempts.

        :param headers: optional extra headers, merged over the session's
        :return: ``requests.Response`` or None when attempts are exhausted
        """
        if times <= 0:
            return None
        try:
            if headers:
                # Merge on top of the session defaults without mutating them.
                headers = dict(self._request.headers, **headers)
                return self._request.post(url, data=data, headers=headers, timeout=timeout)
            return self._request.post(url, data=data, timeout=timeout)
        except Exception as e:
            self._exception(e)
            return self.post_raw(url, data, headers=headers, timeout=timeout, times=times - 1)

    def post(self, url, data, headers=None, timeout=10, times=3):
        """POST and return the response body, or None on failure.

        Same falsy-response caveat as :meth:`get`.
        """
        raw = self.post_raw(url, data, headers=headers, timeout=timeout, times=times)
        if raw:
            return raw.content
        return None


class AsyncCrawler(object):
    """Asynchronous crawler built on Tornado's coroutine HTTP client.

    URLs are drained from a queue by ``concurrency`` worker coroutines.
    Subclasses customize behaviour via :meth:`on_request` (build the fetch
    kwargs) and :meth:`on_response` (consume successful responses).
    """

    def __init__(self, urls, concurrency=10, **kwargs):
        # Reversed in place so ``urls.pop()`` (O(1), from the tail) yields
        # them in the caller's original order. NOTE(review): this mutates the
        # caller's list; ``**kwargs`` is accepted but ignored -- confirm.
        urls.reverse()
        self.urls = urls
        self.concurrency = concurrency  # number of concurrent worker coroutines
        self._q = queues.Queue()        # queue of URLs awaiting fetch
        self._fetching = set()          # URLs that have been dispatched
        self._fetched = set()           # URLs whose responses were handled
        self.results = []               # never written here; presumably for subclasses

    def _fetch(self, url):
        # Assemble per-request kwargs via the subclass hook and start the
        # fetch; raise_error=False turns HTTP errors into normal responses.
        kwargs = self.on_request(url)
        return getattr(httpclient.AsyncHTTPClient(), 'fetch')(url, raise_error=False, **kwargs)

    def on_request(self, url):
        """
        Hook called before each request to assemble the fetch kwargs;
        subclasses should override it (the default implementation returns an
        empty dict). Recognized keys in the returned dict:
            - method          : string, HTTP method, e.g. "GET" or "POST"
            - headers         : dict, request headers
            - body            : string, request body
            - user_agent      : string, String to send as ``User-Agent`` header
            - follow_redirects: bool, whether to follow redirects automatically
            - max_redirects   : int, maximum number of redirects to follow
            - connect_timeout : float, Timeout for initial connection in seconds
            - request_timeout : float, Timeout for entire request in seconds

            - auth_username: string, Username for HTTP authentication
            - auth_password: string, Password for HTTP authentication
            - auth_mode    : string, Authentication mode; default is "basic".

            - proxy_host    : string, HTTP proxy hostname.  To use proxies,
                                 ``proxy_host`` and ``proxy_port`` must be set; ``proxy_username`` and
                                 ``proxy_pass`` are optional.  Proxies are currently only supported
                                 with ``curl_httpclient``.
            - proxy_port    : int,    HTTP proxy port
            - proxy_username: string, HTTP proxy username
            - proxy_password: string, HTTP proxy password

            - validate_cert: bool, For HTTPS requests, validate the server's certificate?
            - ca_certs     : string, filename of CA certificates in PEM format,
                                 or None to use defaults.  See note below when used with
                                 ``curl_httpclient``.
            - client_key   : string, Filename for client SSL key, if any.
            - client_cert  : string, Filename for client SSL certificate, if any.
            - ssl_options  : ssl.SSLContext, `ssl.SSLContext` object for use in
                                 ``simple_httpclient`` (unsupported by ``curl_httpclient``).
                                 Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
                                 and ``client_cert``.

            - if_modified_since  : Timestamp for ``If-Modified-Since`` header
            - decompress_response: bool, Request a compressed response from the server
                                      and decompress it after downloading. Default is True.
            - use_gzip           : bool, Deprecated alias for ``decompress_response``

        :arg string url: the URL about to be requested
        :return dict of keyword arguments for the fetch
        """
        return {}

    def add_url(self, url):
        # Append another URL to the pending list (drained by the workers).
        self.urls.append(url)

    def on_response(self, url, html):
        """
        Callback invoked when a URL has been fetched successfully.
        :param string url: URL the response came from
        :param basestring html: response body for that URL
        """
        pass

    def handle_response(self, url, response):
        # Dispatch on status: 200 -> subclass callback; 599 (Tornado's
        # client-side error/timeout code) -> re-queue for retry.
        # NOTE(review): all other status codes are silently dropped -- confirm.
        if response.code == 200:
            self.on_response(url, response.body)
        elif response.code == 599:    # retry
            self._fetching.remove(url)
            self._q.put(url)

    @gen.coroutine
    def _get_page(self, url):
        # NOTE(review): on failure this resolves to the *exception object*
        # itself, which handle_response will then poke for ``.code`` --
        # confirm raise_error=False already makes this branch unreachable.
        try:
            response = yield self._fetch(url)
        except Exception as e:
            raise gen.Return(e)
        raise gen.Return(response)

    @gen.coroutine
    def _run(self):
        @gen.coroutine
        def fetch_url():
            # Take one URL off the queue, fetch it, then top the queue up
            # with more pending URLs so all workers stay busy.
            current_url = yield self._q.get()
            try:
                if current_url in self._fetching:
                    return
                self._fetching.add(current_url)
                response = yield self._get_page(current_url)
                self.handle_response(current_url, response)    # handle response
                self._fetched.add(current_url)
                # Refill: enqueue up to ``concurrency`` more pending URLs.
                for i in range(self.concurrency):
                    if self.urls:
                        yield self._q.put(self.urls.pop())
            finally:
                self._q.task_done()

        @gen.coroutine
        def worker():
            # Loop forever; torn down when run_sync finishes the IOLoop.
            while True:
                yield fetch_url()

        # Seed the queue with a single URL; workers fan out from there.
        # NOTE(review): raises IndexError when ``urls`` is empty -- confirm.
        self._q.put(self.urls.pop())
        for _ in range(self.concurrency):
            worker()

        # Block (with a very generous timeout) until every queued URL is done.
        yield self._q.join(timeout=timedelta(seconds=300000))

        # Sanity check: everything dispatched should have been handled;
        # print the asymmetric differences when the crawl was incomplete.
        try:
            assert self._fetching == self._fetched
        except AssertionError:
            print(self._fetching-self._fetched)
            print(self._fetched-self._fetching)

    def run(self):
        """Run the crawl to completion on the current IOLoop (blocking)."""
        io_loop = ioloop.IOLoop.current()
        io_loop.run_sync(self._run)


# ################## AsyncCrawler usage #########################################################

'''
class CrawlerDemo(AsyncCrawler):
    def on_request(self, url):
        headers = {
            'User-Agent': 'test from xlzd',
        }
        return {'headers': headers}

    def on_response(self, url, html):
        print url  # , html


@timeit
def old(size):
    import requests
    for page in range(1, size+1):
        url = 'http://xlzd.me/page/%s/' % str(page)
        requests.get(url, headers={
            'User-Agent': 'test from xlzd',
        })
        print url


@timeit
def new(size):
    urls = []
    for page in range(1, size+1):
        urls.append('http://xlzd.me/page/%s/' % str(page))
    s = CrawlerDemo(urls, concurrency=10)
    s.run()

if __name__ == '__main__':
    size = 100
    new(size)
    old(size)
'''
